{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ba0b121f",
   "metadata": {},
   "source": [
    "# Dendritic Spine Clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d9db864",
   "metadata": {},
   "outputs": [],
   "source": [
    "from spine_metrics import SpineMetricDataset\n",
    "from notebook_widgets import SpineMeshDataset, intersection_ratios_mean_distance, create_dir\n",
    "from spine_segmentation import apply_scale\n",
    "from spine_fitter import SpineGrouping\n",
    "from spine_clusterization import SpineClusterizer, DBSCANSpineClusterizer\n",
    "import numpy as np\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "from sklearn.metrics import silhouette_score\n",
    "from typing import Optional\n",
    "from scipy.spatial.distance import jensenshannon\n",
    "import warnings\n",
    "warnings.filterwarnings(action='ignore', category=FutureWarning)\n",
    "\n",
    "\n",
    "dataset_path = \"0.025 0.025 0.1 dataset\"\n",
    "scale = (1, 1, 1)\n",
    "show_reduction_method = \"tsne\"\n",
    "\n",
    "# Spine meshes: read the dataset folder, then apply the configured scale.\n",
    "spine_dataset = SpineMeshDataset().load(dataset_path)\n",
    "spine_dataset.apply_scale(scale)\n",
    "\n",
    "# Merged and reduced manual classification, restricted to the spines of the mesh dataset.\n",
    "manual_classification = (\n",
    "    SpineGrouping()\n",
    "    .load(f\"{dataset_path}/manual_classification/manual_classification_merged_reduced.json\")\n",
    "    .get_spines_subset(spine_dataset.spine_names)\n",
    ")\n",
    "\n",
    "# Metric table, restricted to the manually classified samples.\n",
    "spine_metrics = (\n",
    "    SpineMetricDataset()\n",
    "    .load(f\"{dataset_path}/metrics.csv\")\n",
    "    .get_spines_subset(manual_classification.samples)\n",
    ")\n",
    "\n",
    "# The two feature sets the experiments below cluster on.\n",
    "classic = spine_metrics.get_metrics_subset([\n",
    "    'OpenAngle', 'CVD', \"JunctionArea\", 'AverageDistance', 'Length', 'Area', 'Volume',\n",
    "    'ConvexHullVolume', 'ConvexHullRatio', \"LengthVolumeRatio\", \"LengthAreaRatio\",\n",
    "])\n",
    "chord = spine_metrics.get_metrics_subset(['OldChordDistribution'])\n",
    "\n",
    "# Default score: mean distance between the class-over-cluster distributions.\n",
    "# Later cells reassign score_func before building their widget.\n",
    "score_func = lambda clusterizer: intersection_ratios_mean_distance(manual_classification, clusterizer.grouping, False)\n",
    "\n",
    "# Export locations; parent folders are created before their children.\n",
    "classic_save_path = f\"{dataset_path}/clustering/classic\"\n",
    "chord_save_path = f\"{dataset_path}/clustering/chord/euclidean\"\n",
    "chord_js_save_path = f\"{dataset_path}/clustering/chord/jensen-shannon\"\n",
    "for export_dir in (\n",
    "    f\"{dataset_path}/clustering\",\n",
    "    classic_save_path,\n",
    "    f\"{dataset_path}/clustering/chord\",\n",
    "    chord_save_path,\n",
    "    chord_js_save_path,\n",
    "):\n",
    "    create_dir(export_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fece41ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# elbow method\n",
    "def kmeans_elbow_score(clusterizer: SpineClusterizer) -> float:\n",
    "    \"\"\"Elbow-method score: total within-cluster sum of squared distances.\n",
    "\n",
    "    For every group in ``clusterizer.grouping.groups`` the centroid is the mean of\n",
    "    the members' ``fit_metrics`` rows; the squared Euclidean distances of the\n",
    "    members to that centroid are summed over all groups.\n",
    "\n",
    "    :param clusterizer: clusterizer exposing ``grouping.groups`` and ``fit_metrics``.\n",
    "    :return: sum over all clusters of squared distances to the cluster centroid.\n",
    "    \"\"\"\n",
    "    output = 0.0\n",
    "    for group in clusterizer.grouping.groups.values():\n",
    "        rows = [clusterizer.fit_metrics.row_as_array(spine_name) for spine_name in group]\n",
    "        if not rows:\n",
    "            # an empty cluster has no centroid and contributes nothing\n",
    "            continue\n",
    "        rows = np.asarray(rows, dtype=float)\n",
    "        # the centroid is the MEAN row; the plain sum of rows is not a cluster center\n",
    "        center = rows.mean(axis=0)\n",
    "        output += float(np.sum((rows - center) ** 2))\n",
    "    return output\n",
    "\n",
    "\n",
    "def dbscan_elbow_score(clusterizer: DBSCANSpineClusterizer) -> float:\n",
    "    \"\"\"Elbow-method score for DBSCAN based on k-th nearest-neighbour distances.\n",
    "\n",
    "    With k = ``clusterizer.min_samples``, the distance from every sample to its k-th\n",
    "    returned neighbour is collected and sorted in descending order. The result is the\n",
    "    position of the first distance that is strictly smaller than ``clusterizer.eps``,\n",
    "    or the number of samples when no distance is smaller.\n",
    "\n",
    "    :param clusterizer: clusterizer exposing ``min_samples``, ``metric``, ``eps`` and ``fit_metrics``.\n",
    "    :return: index of the first sorted distance below ``eps`` (an integer count).\n",
    "    \"\"\"\n",
    "    knn = NearestNeighbors(n_neighbors=clusterizer.min_samples, metric=clusterizer.metric)\n",
    "    samples = clusterizer.fit_metrics.as_array()\n",
    "    neighbour_distances, _ = knn.fit(samples).kneighbors(samples)\n",
    "    # last column holds the distance to the k-th returned neighbour; sort descending\n",
    "    kth_descending = -np.sort(-neighbour_distances[:, -1], axis=0)\n",
    "    below_eps = (rank for rank, dist in enumerate(kth_descending) if clusterizer.eps > dist)\n",
    "    return next(below_eps, len(kth_descending))\n",
    "\n",
    "def silhouette(clusterizer: SpineClusterizer, metric: Optional[callable] = None) -> float:\n",
    "    \"\"\"Silhouette score of the clusterizer's current grouping.\n",
    "\n",
    "    :param clusterizer: clusterizer exposing ``grouping.groups``, ``fit_metrics`` and ``metric``.\n",
    "    :param metric: optional pairwise distance ``metric(a, b)``. When given, a full\n",
    "        distance matrix is built with it and scored as \"precomputed\"; otherwise\n",
    "        ``clusterizer.metric`` is passed to ``silhouette_score``.\n",
    "    :return: value returned by ``sklearn.metrics.silhouette_score``.\n",
    "    \"\"\"\n",
    "    samples = []\n",
    "    cluster_ids = []\n",
    "    # label every member row with the position of its group\n",
    "    for cluster_id, members in enumerate(clusterizer.grouping.groups.values()):\n",
    "        for spine in members:\n",
    "            samples.append(clusterizer.fit_metrics.row_as_array(spine))\n",
    "            cluster_ids.append(cluster_id)\n",
    "    cluster_ids = np.array(cluster_ids)\n",
    "\n",
    "    if metric is not None:\n",
    "        pairwise = np.array([[metric(a, b) for a in samples] for b in samples])\n",
    "        return silhouette_score(pairwise, cluster_ids, metric=\"precomputed\")\n",
    "    return silhouette_score(samples, cluster_ids, metric=clusterizer.metric)\n",
    "\n",
    "def js_distance(x, y) -> float:\n",
    "    return np.sqrt(jensenshannon(x, y))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ec61ce3",
   "metadata": {},
   "source": [
    "## k-Means Classic Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53c48863",
   "metadata": {},
   "outputs": [],
   "source": [
    "from notebook_widgets import k_means_clustering_experiment_widget\n",
    "\n",
    "# Score shown by the widget. Alternatives that can be swapped in:\n",
    "#   score_func = lambda clusterizer: intersection_ratios_mean_distance(manual_classification, clusterizer.grouping, False)\n",
    "#   score_func = silhouette\n",
    "score_func = kmeans_elbow_score\n",
    "\n",
    "# NOTE(review): empty string presumably means no dimensionality reduction - confirm in notebook_widgets\n",
    "dim_reduction = \"\"\n",
    "\n",
    "kmeans_classic_widget = k_means_clustering_experiment_widget(\n",
    "    classic, spine_metrics, spine_dataset, score_func,\n",
    "    max_num_of_clusters=100,\n",
    "    classification=manual_classification,\n",
    "    save_folder=classic_save_path,\n",
    "    dim_reduction=dim_reduction,\n",
    "    show_method=show_reduction_method,\n",
    ")\n",
    "display(kmeans_classic_widget)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70c72373",
   "metadata": {},
   "source": [
    "## k-Means Chord Histograms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3570bfd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from notebook_widgets import k_means_clustering_experiment_widget\n",
    "\n",
    "# Score shown by the widget. Alternative that can be swapped in:\n",
    "#   score_func = lambda clusterizer: intersection_ratios_mean_distance(manual_classification, clusterizer.grouping, False)\n",
    "score_func = kmeans_elbow_score\n",
    "\n",
    "kmeans_chord_widget = k_means_clustering_experiment_widget(\n",
    "    chord, spine_metrics, spine_dataset, score_func,\n",
    "    max_num_of_clusters=100,\n",
    "    classification=manual_classification,\n",
    "    save_folder=chord_save_path,\n",
    "    dim_reduction=\"\",\n",
    "    show_method=show_reduction_method,\n",
    ")\n",
    "display(kmeans_chord_widget)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3ea493e",
   "metadata": {},
   "source": [
    "## DBSCAN Classic Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2610a7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from notebook_widgets import dbscan_clustering_experiment_widget\n",
    "\n",
    "# eps sweep handed to the widget\n",
    "min_eps, max_eps, eps_step = 0.2, 6, 0.1\n",
    "\n",
    "# Score shown by the widget. Alternative that can be swapped in:\n",
    "#   score_func = lambda clusterizer: intersection_ratios_mean_distance(manual_classification, clusterizer.grouping, False)\n",
    "score_func = dbscan_elbow_score\n",
    "\n",
    "dbscan_classic_widget = dbscan_clustering_experiment_widget(\n",
    "    classic, spine_metrics, spine_dataset, score_func,\n",
    "    min_eps=min_eps, max_eps=max_eps, eps_step=eps_step,\n",
    "    dim_reduction=\"pca\", show_method=show_reduction_method,\n",
    "    classification=manual_classification, save_folder=classic_save_path,\n",
    ")\n",
    "display(dbscan_classic_widget)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce88febb",
   "metadata": {},
   "source": [
    "## DBSCAN Chord Histograms Euclidean Distance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ea55f30",
   "metadata": {},
   "outputs": [],
   "source": [
    "from notebook_widgets import dbscan_clustering_experiment_widget\n",
    "\n",
    "# eps sweep handed to the widget\n",
    "min_eps, max_eps, eps_step = 0.1, 10, 0.1\n",
    "\n",
    "# Score shown by the widget. Alternative that can be swapped in:\n",
    "#   score_func = lambda clusterizer: intersection_ratios_mean_distance(manual_classification, clusterizer.grouping, False)\n",
    "score_func = dbscan_elbow_score\n",
    "\n",
    "dbscan_chord_widget = dbscan_clustering_experiment_widget(\n",
    "    chord, spine_metrics, spine_dataset, score_func,\n",
    "    min_eps=min_eps, max_eps=max_eps, eps_step=eps_step,\n",
    "    dim_reduction=\"\", show_method=show_reduction_method,\n",
    "    classification=manual_classification, save_folder=chord_save_path,\n",
    ")\n",
    "display(dbscan_chord_widget)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44620329",
   "metadata": {},
   "source": [
    "## DBSCAN Chord Histograms Jensen-Shannon Distance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "434345f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from notebook_widgets import dbscan_clustering_experiment_widget\n",
    "\n",
    "# eps sweep handed to the widget\n",
    "min_eps = 0.1\n",
    "max_eps = 1\n",
    "eps_step = 0.01\n",
    "use_pca = False\n",
    "\n",
    "# Score shown by the widget. Alternative that can be swapped in:\n",
    "#   score_func = lambda clusterizer: intersection_ratios_mean_distance(manual_classification, clusterizer.grouping, False)\n",
    "score_func = dbscan_elbow_score\n",
    "\n",
    "# js_distance comes from the scoring-helpers cell (next to dbscan_elbow_score); it is\n",
    "# deliberately not redefined here so that the notebook has a single definition.\n",
    "# NOTE(review): the other DBSCAN cells pass dim_reduction=/show_method= while this call\n",
    "# still passes use_pca= - confirm which keywords dbscan_clustering_experiment_widget accepts.\n",
    "display(dbscan_clustering_experiment_widget(chord, spine_metrics, spine_dataset, score_func, metric=js_distance,\n",
    "                                            min_eps=min_eps, max_eps=max_eps, eps_step=eps_step, use_pca=use_pca,\n",
    "                                            classification=manual_classification, save_folder=chord_js_save_path))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8b188ec",
   "metadata": {},
   "source": [
    "## View clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e64e99db",
   "metadata": {},
   "outputs": [],
   "source": [
    "from notebook_widgets import inspect_saved_groupings_widget\n",
    "\n",
    "# Browse every grouping saved under the clustering export folder.\n",
    "saved_groupings_widget = inspect_saved_groupings_widget(\n",
    "    f\"{dataset_path}/clustering\",\n",
    "    spine_dataset,\n",
    "    spine_metrics,\n",
    "    chord,\n",
    "    classic,\n",
    "    manual_classification,\n",
    ")\n",
    "display(saved_groupings_widget)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
